{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example of 1p online analysis using a Ring CNN + OnACID\n",
    "\n",
    "This demo shows how to perform online analysis of one-photon data, using a Ring CNN to extract the background followed by processing with the OnACID algorithm. The algorithm relies on a GPU to efficiently estimate and apply the background model, so it is recommended to have access to a GPU when running this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import logging\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "logfile = None # Replace with a path if you want to log to a file\n",
    "logger = logging.getLogger('caiman')\n",
    "# Set to logging.INFO if you want much output, potentially much more output\n",
    "logger.setLevel(logging.WARNING)\n",
    "logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')\n",
    "if logfile is not None:\n",
    "    handler = logging.FileHandler(logfile)\n",
    "else:\n",
    "    handler = logging.StreamHandler()\n",
    "handler.setFormatter(logfmt)\n",
    "logger.addHandler(handler)\n",
    "\n",
    "import caiman as cm\n",
    "import caiman.base.movies\n",
    "from caiman.source_extraction import cnmf as cnmf\n",
    "from caiman.utils.utils import download_demo\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import bokeh.plotting as bpl\n",
    "bpl.output_notebook()\n",
    "# Set play_movies to False if you want to disable play of movies, e.g. for remote-hosted Jupyter environments\n",
    "play_movies = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First specify the data file(s) to be analyzed\n",
    "\n",
    "The `download_demo` function will download the file (if not already present) and store it inside your `caiman_data/example_movies` folder. Alternatively, you can set `fnames` to the path of any file(s) you want to analyze."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fnames=download_demo('blood_vessel_10Hz.mat')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up some parameters\n",
    "\n",
    "Here we set up some parameters for specifying the ring model and running OnACID. We use the same `params` object as in batch processing with CNMF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reuse_model = False                                                 # set to True to reuse an existing ring model\n",
    "path_to_model = None                                                # specify a pre-trained model here if needed \n",
    "gSig = (7, 7)                                                       # expected half size of neurons\n",
    "gnb = 2                                                             # number of background components for OnACID\n",
    "init_batch = 500                                                    # number of frames for initialization and training\n",
    "\n",
    "params_dict = {'fnames': fnames,\n",
    "               'var_name_hdf5': 'Y',                                # name of variable inside mat file where the data is stored\n",
    "               'fr': 10,                                            # frame rate (Hz)\n",
    "               'decay_time': 0.5,                                   # approximate length of transient event in seconds\n",
    "               'gSig': gSig,\n",
    "               'p': 0,                                              # order of AR indicator dynamics\n",
    "               'ring_CNN': True,                                    # SET TO TRUE TO USE RING CNN \n",
    "               'min_SNR': 2.65,                                     # minimum SNR for accepting new components\n",
    "               'SNR_lowest': 0.75,                                  # reject components with SNR below this value\n",
    "               'use_cnn': False,                                    # do not use CNN based test for components\n",
    "               'use_ecc': True,                                     # test eccentricity\n",
    "               'max_ecc': 2.625,                                    # reject components with eccentricity above this value\n",
    "               'rval_thr': 0.70,                                    # correlation threshold for new component inclusion\n",
    "               'rval_lowest': 0.25,                                 # reject components with corr below that value\n",
    "               'ds_factor': 1,                                      # spatial downsampling factor (increases speed but may lose some fine structure)\n",
    "               'nb': gnb,\n",
    "               'motion_correct': False,                             # Flag for motion correction\n",
    "               'init_batch': init_batch,                            # number of frames for initialization (presumably from the first file)\n",
    "               'init_method': 'bare',\n",
    "               'normalize': False,\n",
    "               'expected_comps': 1100,                               # maximum number of expected components used for memory pre-allocation (exaggerate here)\n",
    "               'sniper_mode': False,                                 # flag using a CNN to detect new neurons (o/w space correlation is used)\n",
    "               'dist_shape_update' : True,                           # flag for updating shapes in a distributed way\n",
    "               'min_num_trial': 5,                                   # number of candidate components per frame\n",
    "               'epochs': 3,                                          # number of total passes over the data\n",
    "               'stop_detection': True,                               # Run a last epoch without detecting new neurons  \n",
    "               'K': 50,                                              # initial number of components\n",
    "               'lr': 6e-4,\n",
    "               'lr_scheduler': [0.9, 6000, 10000],\n",
    "               'pct': 0.01,\n",
    "               'path_to_model': path_to_model,                       # where the ring CNN model is saved/loaded\n",
    "               'reuse_model': reuse_model                            # flag for re-using a ring CNN model          \n",
    "              }\n",
    "opts = cnmf.params.CNMFParams(params_dict=params_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now run the Ring-CNN + CaImAn online algorithm (OnACID).\n",
    "\n",
    "The first `init_batch` frames are used for training the Ring CNN model. Once the model is trained, the background is subtracted and the difference is used for initialization purposes. The initialization method chosen here, `bare`, will only search for a small number of neurons and is mostly used to initialize the background components. Initialization with the full CNMF can also be used by choosing `cnmf`.\n",
    "\n",
    "We first create an `OnACID` object, located in the module `online_cnmf`, and pass the parameters similarly to the case of batch processing. We then run the algorithm using the `fit_online` method. Finally, we save the results inside\n",
    "the folder where the Ring CNN model is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_onacid = True\n",
    "\n",
    "if run_onacid:\n",
    "    cnm = cnmf.online_cnmf.OnACID(params=opts)\n",
    "    cnm.fit_online()\n",
    "    fld_name = os.path.dirname(cnm.params.ring_CNN['path_to_model'])\n",
    "    res_name_nm = os.path.join(fld_name, 'onacid_results_nm.hdf5')\n",
    "    cnm.save(res_name_nm)                # save initial results (without any postprocessing)\n",
    "else:\n",
    "    fld_name = os.path.dirname(path_to_model)\n",
    "    res_name = os.path.join(fld_name, 'onacid_results.hdf5')\n",
    "    cnm = cnmf.online_cnmf.load_OnlineCNMF(res_name)\n",
    "    cnm.params.data['fnames'] = fnames"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check speed\n",
    "Create plots that show the processing time per frame and the cumulative processing time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = 10             # plot every ds frames to make more manageable figures\n",
    "init_batch = 500\n",
    "dims, T = caiman.base.movies.get_file_size(fnames, var_name_hdf5='Y')\n",
    "T = np.array(T).sum()\n",
    "n_epochs = cnm.params.online['epochs']\n",
    "T_detect = 1e3*np.hstack((np.zeros(init_batch), cnm.t_detect))\n",
    "T_shapes = 1e3*np.hstack((np.zeros(init_batch), cnm.t_shapes))\n",
    "T_online = 1e3*np.hstack((np.zeros(init_batch), cnm.t_online)) - T_detect - T_shapes\n",
    "plt.figure()\n",
    "plt.stackplot(np.arange(len(T_detect))[::ds], T_online[::ds], T_detect[::ds], T_shapes[::ds],\n",
    "              colors=['tab:red', 'tab:purple', 'tab:brown'])\n",
    "plt.legend(labels=['process', 'detect', 'shapes'], loc=2)\n",
    "plt.title('Processing time allocation')\n",
    "plt.xlabel('Frame #')\n",
    "plt.ylabel('Processing time [ms]')\n",
    "max_val = 80\n",
    "plt.ylim([0, max_val]);\n",
    "plt.plot([init_batch, init_batch], [0, max_val], '--k')\n",
    "for i in range(n_epochs - 1):\n",
    "    plt.plot([(i+1)*T, (i+1)*T], [0, max_val], '--k')\n",
    "plt.xlim([0, n_epochs*T]);\n",
    "plt.savefig(os.path.join(fld_name, 'time_per_frame_ds.pdf'), bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_batch = 500\n",
    "plt.figure()\n",
    "tc_init = cnm.t_init*np.ones(T*n_epochs)\n",
    "ds = 10\n",
    "#tc_mot = np.hstack((np.zeros(init_batch), np.cumsum(T_motion)/1000))\n",
    "tc_prc = np.cumsum(T_online)/1000#np.hstack((np.zeros(init_batch), ))\n",
    "tc_det = np.cumsum(T_detect)/1000#np.hstack((np.zeros(init_batch), ))\n",
    "tc_shp = np.cumsum(T_shapes)/1000#np.hstack((np.zeros(init_batch), ))\n",
    "plt.stackplot(np.arange(len(tc_init))[::ds], tc_init[::ds], tc_prc[::ds], tc_det[::ds], tc_shp[::ds],\n",
    "              colors=['g', 'tab:red', 'tab:purple', 'tab:brown'])\n",
    "plt.legend(labels=['initialize', 'process', 'detect', 'shapes'], loc=2)\n",
    "plt.title('Processing time allocation')\n",
    "plt.xlabel('Frame #')\n",
    "plt.ylabel('Processing time [s]')\n",
    "max_val = (tc_prc[-1] + tc_det[-1] + tc_shp[-1] + cnm.t_init)*1.05\n",
    "for i in range(n_epochs - 1):\n",
    "    plt.plot([(i+1)*T, (i+1)*T], [0, max_val], '--k')\n",
    "plt.xlim([0, n_epochs*T]);\n",
    "plt.ylim([0, max_val])\n",
    "plt.savefig(os.path.join(fld_name, 'time_cumulative_ds.pdf'), bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Cost of estimating model and running first epoch: {:.2f}s'.format(tc_prc[T] + tc_det[T] + tc_shp[T] + tc_init[T]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Do some initial plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first compute background summary images\n",
    "images = cm.load(fnames, var_name_hdf5='Y', subindices=slice(None, None, 2))\n",
    "cn_filter, pnr = cm.summary_images.correlation_pnr(images, gSig=3, swap_dim=False) # change swap dim if output looks weird, it is a problem with tiffile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(15, 7))\n",
    "plt.subplot(1,2,1); plt.imshow(cn_filter); plt.colorbar()\n",
    "plt.subplot(1,2,2); plt.imshow(pnr); plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.plot_contours_nb(img=cn_filter, idx=cnm.estimates.idx_components, line_color='white', thr=0.3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View components\n",
    "\n",
    "Now inspect the components extracted by OnACID. Note that if a single pass was used, several components would be non-zero for only part of the time interval, indicating that they were detected online by OnACID.\n",
    "\n",
    "Note that if you get a data rate error, you can start Jupyter using:\n",
    "`jupyter lab --ZMQChannelsWebsocketConnection.iopub_data_rate_limit=1.0e10`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.nb_view_components(img=cn_filter, denoised_color='red')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load ring model to filter the data\n",
    "Filter the data with the learned Ring CNN model and create a memory-mapped file with the background-subtracted data. We will use this to run the quality tests and screen for false positive components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "save_file = True\n",
    "if save_file:\n",
    "    from caiman.utils.nn_models import create_LN_model\n",
    "    model_LN = create_LN_model(images, shape=opts.data['dims'] + (1,), n_channels=opts.ring_CNN['n_channels'],\n",
    "                               width=opts.ring_CNN['width'], use_bias=opts.ring_CNN['use_bias'], gSig=gSig[0],\n",
    "                               use_add=opts.ring_CNN['use_add'])\n",
    "    model_LN.load_weights(cnm.params.ring_CNN['path_to_model'])\n",
    "\n",
    "    # Load the data in batches and save them\n",
    "\n",
    "    m = []\n",
    "    saved_files = []\n",
    "    batch_length = 256\n",
    "    for i in range(0, T, batch_length):\n",
    "        images = cm.load(fnames, var_name_hdf5='Y', subindices=slice(i, i + batch_length))\n",
    "        images_filt = np.squeeze(model_LN.predict(np.expand_dims(images, axis=-1)))\n",
    "        temp_file = os.path.join(fld_name, 'pfc_back_removed_' + format(i, '05d') + '.h5')\n",
    "        saved_files.append(temp_file)\n",
    "        m = cm.movie(np.maximum(images - images_filt, 0))\n",
    "        m.save(temp_file)\n",
    "else:\n",
    "    saved_files = glob.glob(os.path.join(fld_name, 'pfc_back_removed_*'))\n",
    "    saved_files.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fname_mmap = cm.save_memmap([saved_files], order='C', border_to_0=0)\n",
    "Yr, dims, T = cm.load_memmap(fname_mmap)\n",
    "images_mmap = Yr.T.reshape((T,) + dims, order='F')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.params.merging['merge_thr'] = 0.7\n",
    "cnm.estimates.c1 = np.zeros(cnm.estimates.A.shape[-1])\n",
    "cnm.estimates.bl = np.zeros(cnm.estimates.A.shape[-1])\n",
    "cnm.estimates.neurons_sn = np.zeros(cnm.estimates.A.shape[-1])\n",
    "cnm.estimates.g = None\n",
    "cnm.estimates.merge_components(Yr, cnm.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate components and compare again\n",
    "\n",
    "We run the component evaluation tests to screen for false positive components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cnm.params.quality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.evaluate_components(imgs=images_mmap, params=cnm.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.plot_contours_nb(img=cn_filter, idx=cnm.estimates.idx_components, line_color='white')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnm.estimates.nb_view_components(idx=cnm.estimates.idx_components, img=cn_filter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare against CNMF-E results\n",
    "\n",
    "We download the results of CNMF-E on the same dataset and compare."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnmfe_results = download_demo('online_vs_offline.npz')\n",
    "locals().update(np.load(cnmfe_results, allow_pickle=True))\n",
    "A_patch_good = A_patch_good.item()\n",
    "estimates_gt = cnmf.estimates.Estimates(A=A_patch_good, C=C_patch_good, dims=dims)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "maxthr=0.01\n",
    "cnm.estimates.A_thr=None\n",
    "cnm.estimates.threshold_spatial_components(maxthr=maxthr)\n",
    "estimates_gt.A_thr=None\n",
    "estimates_gt.threshold_spatial_components(maxthr=maxthr*10)\n",
    "min_size = np.pi*(gSig[0]/1.5)**2\n",
    "max_size = np.pi*(gSig[0]*1.5)**2\n",
    "ntk = cnm.estimates.remove_small_large_neurons(min_size_neuro=min_size, max_size_neuro=2*max_size)\n",
    "gtk = estimates_gt.remove_small_large_neurons(min_size_neuro=min_size, max_size_neuro=2*max_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m1, m2, nm1, nm2, perf = cm.base.rois.register_ROIs(estimates_gt.A_thr[:, estimates_gt.idx_components],\n",
    "                                                  cnm.estimates.A_thr[:, cnm.estimates.idx_components],\n",
    "                                                  dims, align_flag=False, thresh_cost=.7, plot_results=True,\n",
    "                                                  Cn=cn_filter, enclosed_thr=None)[:-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Print performance results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, v in perf.items():\n",
    "    print(k + ':', '%.4f' % v, end='  ')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_name = os.path.join(fld_name, 'onacid_results.hdf5')\n",
    "cnm.save(res_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make some plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.lines as mlines\n",
    "lp, hp = np.nanpercentile(cn_filter, [5, 98])\n",
    "A_onacid = cnm.estimates.A_thr.toarray().copy()\n",
    "A_onacid /= A_onacid.max(0)\n",
    "\n",
    "A_TP = estimates_gt.A[:, m1].toarray() #cnm.estimates.A[:, cnm.estimates.idx_components[m2]].toarray()\n",
    "A_TP = A_TP.reshape(dims + (-1,), order='F').transpose(2,0,1)\n",
    "A_FN = estimates_gt.A[:, nm1].toarray()\n",
    "A_FN = A_FN.reshape(dims + (-1,), order='F').transpose(2,0,1)\n",
    "A_FP = A_onacid[:,cnm.estimates.idx_components[nm2]]\n",
    "A_FP = A_FP.reshape(dims + (-1,), order='F').transpose(2,0,1)\n",
    "\n",
    "\n",
    "plt.figure(figsize=(15, 12))\n",
    "plt.imshow(cn_filter, vmin=lp, vmax=hp, cmap='viridis')\n",
    "plt.colorbar();\n",
    "\n",
    "for aa in A_TP:\n",
    "    plt.contour(aa, [0.05], colors='k');\n",
    "for aa in A_FN:\n",
    "    plt.contour(aa, [0.05], colors='r');\n",
    "for aa in A_FP:\n",
    "    plt.contour(aa, [0.25], colors='w');\n",
    "cl = ['k', 'r', 'w']\n",
    "lb = ['both', 'CNMF-E only', 'ring CNN only']\n",
    "day = [mlines.Line2D([], [], color=cl[i], label=lb[i]) for i in range(3)]\n",
    "plt.legend(handles=day, loc=3)\n",
    "plt.axis('off');\n",
    "plt.margins(0, 0);\n",
    "plt.savefig(os.path.join(fld_name, 'ring_CNN_contours_gSig_3.pdf'), bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A_rej = cnm.estimates.A[:, cnm.estimates.idx_components_bad].toarray()\n",
    "A_rej = A_rej.reshape(dims + (-1,), order='F').transpose(2,0,1)\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.imshow(cn_filter, vmin=lp, vmax=hp, cmap='viridis')\n",
    "plt.title('Rejected Components')\n",
    "for aa in A_rej:\n",
    "    plt.contour(aa, [0.05], colors='w');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show the learned filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from caiman.utils.nn_models import create_LN_model\n",
    "model_LN = create_LN_model(images, shape=opts.data['dims'] + (1,), n_channels=opts.ring_CNN['n_channels'],\n",
    "                           width=opts.ring_CNN['width'], use_bias=opts.ring_CNN['use_bias'], gSig=gSig[0],\n",
    "                           use_add=opts.ring_CNN['use_add'])\n",
    "model_LN.load_weights(cnm.params.ring_CNN['path_to_model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "W = model_LN.get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(2,2,1); plt.imshow(np.squeeze(W[0][:,:,:,0])); plt.colorbar(); plt.title('Ring Kernel 1')\n",
    "plt.subplot(2,2,2); plt.imshow(np.squeeze(W[0][:,:,:,1])); plt.colorbar(); plt.title('Ring Kernel 2')\n",
    "plt.subplot(2,2,3); plt.imshow(np.squeeze(W[-1][:,:,0])); plt.colorbar(); plt.title('Multiplicative Layer 1')\n",
    "plt.subplot(2,2,4); plt.imshow(np.squeeze(W[-1][:,:,1])); plt.colorbar(); plt.title('Multiplicative Layer 2');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make a movie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m1 = cm.load(fnames, var_name_hdf5='Y')  # original data\n",
    "m2 = cm.load(fname_mmap)  # background subtracted data\n",
    "m3 = m1 - m2  # estimated background\n",
    "m4 = cm.movie(cnm.estimates.A[:,cnm.estimates.idx_components].dot(cnm.estimates.C[cnm.estimates.idx_components])).reshape(dims + (T,)).transpose(2,0,1)\n",
    "              # estimated components\n",
    "\n",
    "nn = 0.01\n",
    "mm = 1 - nn/4   # normalize movies by quantiles\n",
    "m1 = (m1 - np.quantile(m1[:1000], nn))/(np.quantile(m1[:1000], mm) - np.quantile(m1[:1000], nn))\n",
    "m2 = (m2 - np.quantile(m2[:1000], nn))/(np.quantile(m2[:1000], mm) - np.quantile(m2[:1000], nn))\n",
    "m3 = (m3 - np.quantile(m3[:1000], nn))/(np.quantile(m3[:1000], mm) - np.quantile(m3[:1000], nn))\n",
    "m4 = (m4 - np.quantile(m4[:1000], nn))/(np.quantile(m4[:1000], mm) - np.quantile(m4[:1000], nn))\n",
    "\n",
    "m = cm.concatenate((cm.concatenate((m1.transpose(0,2,1), m3.transpose(0,2,1)), axis=2),\n",
    "                    cm.concatenate((m2.transpose(0,2,1), m4), axis=2)), axis=1)\n",
    "\n",
    "if play_movies:\n",
    "    m[:3000].play(magnification=2, q_min=1, plot_text=True,\n",
    "                  save_movie=True, movie_name=os.path.join(fld_name, 'movie_pytorch.avi'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "caiman_pytorch2",
   "language": "python",
   "name": "caiman_pytorch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
